# MIT License

# Copyright (c) 2024 dechin

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# cythonize -i -f wrapper.pyx
import os
import site
from pathlib import Path
cimport cython

# Site-packages directories where the compiled libcufes may have been installed:
# the system/venv site dir and the per-user site dir, plus the prefix three
# levels above each (e.g. <prefix>/lib/pythonX.Y/site-packages -> <prefix>).
site_path_ = site.getsitepackages()[0]
# getusersitepackages() always returns a path; site.USER_SITE may still be None
# (e.g. when site initialisation was skipped), which would break Path() below.
user_site_path_ = site.getusersitepackages()
site_path = str(Path(site_path_).parent.parent.parent)
user_site_path = str(Path(user_site_path_).parent.parent.parent)

# Append the candidate directories to LD_LIBRARY_PATH. Use .get(): indexing
# os.environ directly raised KeyError whenever the variable was unset.
# NOTE(review): on glibc the dynamic loader reads LD_LIBRARY_PATH at process
# start, so this mainly affects child processes, not dlopen() in this process.
_extra_ld_paths = os.pathsep.join((site_path_, site_path, user_site_path_, user_site_path))
_ld_library_path = os.environ.get('LD_LIBRARY_PATH')
if _ld_library_path:
    os.environ['LD_LIBRARY_PATH'] = _ld_library_path + os.pathsep + _extra_ld_paths
else:
    os.environ['LD_LIBRARY_PATH'] = _extra_ld_paths
del _extra_ld_paths, _ld_library_path

import numpy as np
cimport numpy as np

# Module-level numeric constants. They are declared as single-precision C
# floats, so each literal is rounded to float32 precision.
# pi; not referenced elsewhere in the visible code of this file.
cdef float PI = 3.14159265359
# sqrt(2*pi); not referenced elsewhere in the visible code of this file.
cdef float SQRT2PI = 2.506628274631
# (2*pi)**(3/2). Multiplied by the three bandwidths bw[0]*bw[1]*bw[2] to form
# `volume` in PathFES / FastPathFES, i.e. the normalisation of a 3-D Gaussian.
cdef float SQRT2PI3 = 15.74961
# Energy scale used in res = -kT * log(...). Presumably k_B*T in kcal/mol at
# roughly 300 K -- TODO confirm units and temperature.
cdef float kT = 0.596128107

# Declarations for the POSIX dynamic-loading API (<dlfcn.h>). They are used to
# open libcufes at import time and to look up its entry points by name.
cdef extern from "<dlfcn.h>" nogil:
    void *dlopen(const char *, int)    # open a shared library; NULL on failure
    char *dlerror()                    # message for the most recent dl* failure
    void *dlsym(void *, const char *)  # address of a named symbol, or NULL
    int dlclose(void *)                # release a handle (not called in the visible code)
    enum:
        # Defer resolution of function symbols until they are first called.
        RTLD_LAZY

# One 3-D point: three consecutive single-precision floats.
ctypedef struct CRD:
    float x, y, z

# One point on the path / collective variable. It wraps a single CRD, so a
# C-contiguous (N, 3) float32 buffer can be reinterpreted as N PATH records
# (the wrappers below cast &cv[0][0] to PATH* exactly this way).
ctypedef struct PATH:
    CRD crds

# Function-pointer types for the libcufes entry points fetched with dlsym().
# Parameter meanings follow how the wrappers call them; the C side is not
# visible here -- confirm against the libcufes sources.
# GetWeight(length, bias*, shift, weight_out*)
ctypedef int (*GetWeightFunc)(int, float*, float, float*)
# GetDist / GaussGetDist(cv_length, crd*, cv*, dis_out*)
ctypedef int (*GetDistFunc)(int, CRD*, PATH*, float*)
# GaussGetDistHeight(cv_length, crd*, cv*, dis_out*, height*)
ctypedef int (*GaussDistFunc)(int, CRD*, PATH*, float*, float*)
# StickCv(cv_length, host cv*) -> handle to a library-owned copy of cv
ctypedef PATH* (*StickCv)(int CV_LENGTH, PATH* cv)
# ReleaseCv(handle returned by StickCv)
ctypedef int (*ReleaseCv)(PATH*)
# GaussGetDistHeightDevice(cv_length, crd*, StickCv handle, dis_out*, height*)
ctypedef int (*FastGaussDistFunc)(int, CRD*, PATH*, float*, float*)

# Locate the compiled CUDA helper library. Candidates are tried in order and
# the first existing file wins.
_cufes_candidates = (
    site_path + "/cyfes/libcufes.1.so",
    site_path_ + "/cyfes/libcufes.1.so",
    user_site_path + "/cyfes/libcufes.1.so",
    user_site_path_ + "/cyfes/libcufes.1.so",
)
for cufes_path in _cufes_candidates:
    if os.path.exists(cufes_path):
        break
else:
    raise ValueError('No dynamic link file libcufes found!')

# Keep the encoded path in a named variable so the char* handed to dlopen()
# refers to a live bytes object.
_cufes_path_bytes = cufes_path.encode('utf-8')
cdef void* handle = dlopen(_cufes_path_bytes, RTLD_LAZY)
# Fail loudly at import. A NULL handle would otherwise make every later
# dlsym() return NULL, and calling those pointers crashes the process.
if handle == NULL:
    raise OSError('Failed to load {}: {}'.format(
        cufes_path, (<bytes>dlerror()).decode('utf-8', 'replace')))

cpdef float[:] get_weight(float[:] bias):
    """Compute one weight per bias value with libcufes ``GetWeight``.

    ``GetWeight`` is called once with the number of entries, a pointer to the
    bias values, a fixed shift of 0.0 and an output buffer for the weights.

    Parameters
    ----------
    bias : float32 1-D memoryview
        Bias values. A strided view is allowed: a contiguous copy is what gets
        passed to C, because the library reads consecutive floats.

    Returns
    -------
    float32 1-D memoryview of the same length, filled by ``GetWeight``.
    An empty input returns an empty result without calling the library.

    Raises
    ------
    RuntimeError
        If the ``GetWeight`` symbol is missing from libcufes.
    """
    cdef:
        GetWeightFunc GetWeight
        void* sym
        int success
        int CV_LENGTH = bias.shape[0]
        float shift = 0.0
        float[::1] bias_c
        float[:] weight = np.zeros((CV_LENGTH, ), dtype=np.float32)
    if CV_LENGTH == 0:
        # Nothing to weight; &bias[0] would be out of range.
        return weight
    sym = dlsym(handle, "GetWeight")
    if sym == NULL:
        raise RuntimeError('Symbol GetWeight not found in libcufes')
    GetWeight = <GetWeightFunc>sym
    # Ensure the C side sees the elements densely packed, in order.
    bias_c = np.ascontiguousarray(bias, dtype=np.float32)
    # NOTE(review): GetWeight's return-code convention is not visible here,
    # so it is still ignored.
    success = GetWeight(CV_LENGTH, &bias_c[0], shift, &weight[0])
    return weight

cpdef float[:] get_dis(float[:] crd, float[:, :] cv):
    """Evaluate libcufes ``GetDist`` for one point against every row of ``cv``.

    Parameters
    ----------
    crd : float32 1-D memoryview with at least 3 entries
        The first three values are read as one ``CRD`` (x, y, z).
    cv : float32 2-D memoryview of shape (N, 3)
        Reinterpreted as N packed ``PATH`` records (one CRD each).

    Returns
    -------
    float32 1-D memoryview of length N, filled by ``GetDist``.
    If ``cv`` has no rows, an empty result is returned without calling C.

    Raises
    ------
    ValueError
        If ``crd`` has fewer than 3 values or ``cv`` does not have 3 columns.
    RuntimeError
        If the ``GetDist`` symbol is missing from libcufes.
    """
    cdef:
        GetDistFunc GetDist
        void* sym
        int success
        int CV_LENGTH = cv.shape[0]
        float[::1] crd_c
        float[:, ::1] cv_c
        float[:] dis = np.zeros((CV_LENGTH, ), dtype=np.float32)
    if CV_LENGTH == 0:
        return dis
    # The casts to CRD*/PATH* below assume exactly this packed layout.
    if crd.shape[0] < 3:
        raise ValueError('crd must hold at least 3 values (x, y, z)')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    sym = dlsym(handle, 'GetDist')
    if sym == NULL:
        raise RuntimeError('Symbol GetDist not found in libcufes')
    GetDist = <GetDistFunc>sym
    # Strided inputs would be misread as packed structs; copy if needed.
    crd_c = np.ascontiguousarray(crd, dtype=np.float32)
    cv_c = np.ascontiguousarray(cv, dtype=np.float32)
    # NOTE(review): the return code is ignored, as before; its meaning is not visible here.
    success = GetDist(CV_LENGTH, <CRD*>&crd_c[0], <PATH*>&cv_c[0, 0], &dis[0])
    return dis

cpdef float[:, :] batch_get_dis(float[:, :] crd, float[:, :] cv):
    """Evaluate libcufes ``GaussGetDist`` for every row of ``crd``.

    Parameters
    ----------
    crd : float32 2-D memoryview of shape (M, >=3)
        Each row's first three values are read as one ``CRD`` (x, y, z).
    cv : float32 2-D memoryview of shape (N, 3)
        Reinterpreted as N packed ``PATH`` records.

    Returns
    -------
    float32 2-D memoryview of shape (M, N). Row ``i`` is filled by one
    ``GaussGetDist`` call for ``crd[i]``. If either input has no rows, an
    all-zero result of that shape is returned without calling C.

    Raises
    ------
    ValueError
        If ``crd`` has fewer than 3 columns or ``cv`` does not have 3 columns.
    RuntimeError
        If the ``GaussGetDist`` symbol is missing from libcufes.
    """
    cdef:
        GetDistFunc GaussGetDist
        void* sym
        int success, i
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        float[:, ::1] crd_c
        float[:, ::1] cv_c
        float[:, :] dis = np.zeros((CRD_LENGTH, CV_LENGTH), dtype=np.float32)
    if CRD_LENGTH == 0 or CV_LENGTH == 0:
        return dis
    if crd.shape[1] < 3:
        raise ValueError('crd must have at least 3 columns (x, y, z)')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    # NOTE(review): this looks up 'GaussGetDist' (not 'GetDist' as get_dis
    # does) and assumes it matches the GetDistFunc signature -- confirm
    # against libcufes.
    sym = dlsym(handle, 'GaussGetDist')
    if sym == NULL:
        raise RuntimeError('Symbol GaussGetDist not found in libcufes')
    GaussGetDist = <GetDistFunc>sym
    # Rows are handed to C as packed structs, so make both inputs C-contiguous.
    crd_c = np.ascontiguousarray(crd, dtype=np.float32)
    cv_c = np.ascontiguousarray(cv, dtype=np.float32)
    for i in range(CRD_LENGTH):
        success = GaussGetDist(CV_LENGTH, <CRD*>&crd_c[i, 0], <PATH*>&cv_c[0, 0], &dis[i, 0])
    return dis

cpdef float[:] PathFES(float[:, :] crd, float[:, :] cv, float[:] bw, float[:] bias):
    """Compute a shifted free-energy profile over the points in ``crd``.

    Steps:

    1. ``height[j] = weight[j] / ((2*pi)**1.5 * bw[0] * bw[1] * bw[2])``,
       where ``weight = get_weight(bias)``.
    2. Row ``dis[i]`` is filled by libcufes ``GaussGetDistHeight`` for
       ``crd[i]`` against all of ``cv``, using ``height``.
    3. ``res[i] = -kT * log(mean_j dis[i, j])``.
    4. ``res`` is shifted so that its minimum is 0.

    Parameters
    ----------
    crd : float32 2-D memoryview of shape (M, >=3)
        Evaluation points.
    cv : float32 2-D memoryview of shape (N, 3), N >= 1
        Path points.
    bw : float32 1-D memoryview
        The first three entries are the bandwidths.
    bias : float32 1-D memoryview
        Bias per path point; at least N entries are read.

    Returns
    -------
    float32 1-D memoryview of length M. An empty ``crd`` gives an empty result.

    Raises
    ------
    ValueError
        If ``cv`` is empty or has the wrong number of columns, or ``crd`` has
        fewer than 3 columns.
    RuntimeError
        If the ``GaussGetDistHeight`` symbol is missing from libcufes.
    """
    cdef:
        float volume, res_min
        float[:] weight
        GaussDistFunc GaussGetDist
        void* sym
        int success, i
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        float[:, ::1] crd_c
        float[:, ::1] cv_c
        float[:] height = np.zeros((CV_LENGTH, ), dtype=np.float32)
        float[:, :] dis = np.zeros((CRD_LENGTH, CV_LENGTH), dtype=np.float32)
        float[:] res = np.zeros((CRD_LENGTH, ), dtype=np.float32)
    if CRD_LENGTH == 0:
        # Nothing to evaluate; np.min() below would fail on an empty array.
        return res
    if CV_LENGTH == 0:
        raise ValueError('cv must contain at least one point')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    if crd.shape[1] < 3:
        raise ValueError('crd must have at least 3 columns (x, y, z)')
    sym = dlsym(handle, 'GaussGetDistHeight')
    if sym == NULL:
        raise RuntimeError('Symbol GaussGetDistHeight not found in libcufes')
    GaussGetDist = <GaussDistFunc>sym
    # Rows are passed to C as packed CRD/PATH structs, so require C order.
    crd_c = np.ascontiguousarray(crd, dtype=np.float32)
    cv_c = np.ascontiguousarray(cv, dtype=np.float32)

    volume = SQRT2PI3 * bw[0] * bw[1] * bw[2]
    weight = get_weight(bias)
    for i in range(CV_LENGTH):
        height[i] = weight[i] / volume
    for i in range(CRD_LENGTH):
        success = GaussGetDist(CV_LENGTH, <CRD*>&crd_c[i, 0], <PATH*>&cv_c[0, 0], &dis[i, 0], &height[0])

    # One vectorized reduction over all rows instead of a NumPy call per row.
    res = (-kT * np.log(np.sum(np.asarray(dis), axis=1) / CV_LENGTH)).astype(np.float32)
    res_min = np.min(res)
    for i in range(CRD_LENGTH):
        res[i] -= res_min
    return res

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef float[:] FastPathFES(float[:, :] crd, float[:, :] cv, float[:] bw, float[:] bias):
    """Variant of ``PathFES`` that hands ``cv`` to the library once.

    ``cv`` is passed to libcufes ``StickCv`` a single time, and the returned
    handle is used for every ``GaussGetDistHeightDevice`` call. The handle is
    always released with ``ReleaseCv``, even if an error occurs. The result
    has the same form as ``PathFES``:
    ``res[i] = -kT * log(mean_j dis[j])``, shifted so ``min(res) == 0``.

    Bounds checking is disabled in this function, so every shape requirement
    is checked explicitly before any raw pointer is taken.

    Parameters
    ----------
    crd : float32 2-D memoryview of shape (M, >=3)
        Evaluation points.
    cv : float32 2-D memoryview of shape (N, 3), N >= 1
        Path points.
    bw : float32 1-D memoryview with at least 3 bandwidths
    bias : float32 1-D memoryview with at least N entries

    Returns
    -------
    float32 1-D memoryview of length M. An empty ``crd`` gives an empty result.

    Raises
    ------
    ValueError
        If the input shapes are inconsistent.
    RuntimeError
        If a libcufes symbol is missing or ``StickCv`` returns NULL.
    """
    cdef:
        float volume, res_min
        float[:] weight
        void* sym_dist = dlsym(handle, 'GaussGetDistHeightDevice')
        void* sym_stick = dlsym(handle, 'StickCv')
        void* sym_release = dlsym(handle, 'ReleaseCv')
        FastGaussDistFunc GaussGetDist
        StickCv to_cuda
        ReleaseCv free_cuda
        PATH* cv_device = NULL
        int success, i
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        float[:, ::1] crd_c
        float[:, ::1] cv_c
        float[:] height = np.zeros((CV_LENGTH, ), dtype=np.float32)
        float[:] dis = np.zeros((CV_LENGTH, ), dtype=np.float32)
        float[:] res = np.zeros((CRD_LENGTH, ), dtype=np.float32)

    if sym_dist == NULL or sym_stick == NULL or sym_release == NULL:
        raise RuntimeError('GaussGetDistHeightDevice/StickCv/ReleaseCv not found in libcufes')
    GaussGetDist = <FastGaussDistFunc>sym_dist
    to_cuda = <StickCv>sym_stick
    free_cuda = <ReleaseCv>sym_release

    if CRD_LENGTH == 0:
        return res
    # With boundscheck/wraparound off, these checks are the only guard
    # against out-of-range reads.
    if CV_LENGTH == 0:
        raise ValueError('cv must contain at least one point')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    if crd.shape[1] < 3:
        raise ValueError('crd must have at least 3 columns (x, y, z)')
    if bw.shape[0] < 3:
        raise ValueError('bw must hold 3 bandwidths')
    if bias.shape[0] < CV_LENGTH:
        raise ValueError('bias must have at least as many entries as cv has rows')
    # Rows are passed to C as packed CRD/PATH structs, so require C order.
    crd_c = np.ascontiguousarray(crd, dtype=np.float32)
    cv_c = np.ascontiguousarray(cv, dtype=np.float32)

    volume = SQRT2PI3 * bw[0] * bw[1] * bw[2]
    weight = get_weight(bias)
    for i in range(CV_LENGTH):
        height[i] = weight[i] / volume

    cv_device = to_cuda(CV_LENGTH, <PATH*>&cv_c[0, 0])
    # NOTE(review): assumes StickCv signals failure by returning NULL -- confirm.
    if cv_device == NULL:
        raise RuntimeError('StickCv failed to upload cv')
    try:
        for i in range(CRD_LENGTH):
            success = GaussGetDist(CV_LENGTH, <CRD*>&crd_c[i, 0], cv_device, &dis[0], &height[0])
            res[i] = -kT * np.log(np.sum(dis) / CV_LENGTH)
    finally:
        # Release the library-side copy even if the loop raised.
        success = free_cuda(cv_device)

    # Typed memoryviews do not support arithmetic (`res -= x`), so shift
    # element-wise, as PathFES does.
    res_min = np.min(res)
    for i in range(CRD_LENGTH):
        res[i] -= res_min
    return res

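# NOTE: `handle` is never passed to dlclose(), so libcufes stays loaded for the
# lifetime of the process. Closing it would invalidate every pointer obtained
# from dlsym() above.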
